# Replication Archive for Generalizability of IR experiments beyond the US
# These functions are adaptations of the method described in Ding, Feller, and Miratrix 2015
# Some of the functions are borrowed from Coppock 2019, PSRM:
# Generalizing from Survey Experiments Conducted on Mechanical Turk: A Replication Approach

# Potential-outcome matrix under the constant-effects model.
#
# Y: observed outcomes. Z: condition labels; the first level of factor(Z) is
#   the baseline.
# diff_in_means_vec: one assumed constant effect per non-baseline level, in
#   level order.
# Each unit's baseline outcome is recovered by subtracting the effect of the
# arm it was observed in. Column k of the result is that baseline outcome
# plus the k-th effect (0 for the baseline column). Columns are named by the
# levels of factor(Z).
get_pos_constant_fx <- function(Y, Z, diff_in_means_vec) {
  z_factor <- factor(Z)
  arms <- levels(z_factor)
  n_arms <- length(arms)

  # Indicator columns for every non-baseline arm (intercept column dropped).
  arm_dummies <- model.matrix(~ z_factor)[, -1, drop = FALSE]
  # Strip each unit's assumed effect to get its baseline potential outcome.
  baseline_po <- as.vector(Y - arm_dummies %*% diff_in_means_vec)

  baseline_block <- matrix(rep(baseline_po, times = n_arms), ncol = n_arms)
  shift_block <- matrix(rep(c(0, diff_in_means_vec), times = length(Y)),
                        ncol = n_arms,
                        byrow = TRUE)
  result <- baseline_block + shift_block
  colnames(result) <- arms
  result
}

# Difference-in-means effect estimates of each condition relative to the
# baseline (first level of factor(Z)).
#
# Y: numeric outcomes.
# Z: condition labels. Unused factor levels are dropped, matching
#   get_pos_constant_fx().
# ci_quantile: normal quantile probability for the interval half-width,
#   qnorm(ci_quantile) * SE. The default 0.9999 keeps the original
#   (deliberately wide) interval.
# SE for condition k is sqrt(var_k / n_k + var_0 / n_0).
# Returns a data frame with one row per non-baseline condition and columns
#   ATEs, SEs, UIs (upper bound) and LIs (lower bound).
get_dif_in_means <- function(Y, Z, ci_quantile = 0.9999) {
  Z_fac <- factor(Z)
  means_vec <- tapply(X = Y, INDEX = Z_fac, FUN = mean)
  vars_vec <- tapply(X = Y, INDEX = Z_fac, FUN = var)
  ns_vec <- tapply(X = Y, INDEX = Z_fac, FUN = length)
  vars_over_ns_vec <- vars_vec / ns_vec

  diff_in_means_vec <- means_vec[-1] - means_vec[1]
  diff_in_means_ses_vec <-
    sqrt(vars_over_ns_vec[-1] + vars_over_ns_vec[1])
  # Same half-width on both sides; computed once.
  half_width_vec <- qnorm(p = ci_quantile) * diff_in_means_ses_vec

  data.frame(
    ATEs = diff_in_means_vec,
    SEs = diff_in_means_ses_vec,
    UIs = diff_in_means_vec + half_width_vec,
    LIs = diff_in_means_vec - half_width_vec
  )
}

## Shifted Kolmogorov-Smirnov (SKS) statistic.
## Each sample is centered at its own mean, which is equivalent to shifting
## Y1 by the estimated effect mean(Y1) - mean(Y0). The statistic is the
## largest absolute gap between the two empirical CDFs.
##
## x: outcomes in the comparison arm (Y1).
## y: outcomes in the baseline arm (Y0).
## Returns a single number in [0, 1].
SKS.stat_AEC <- function(x, y) {
  centered_x <- x - mean(x)
  centered_y <- y - mean(y)

  # Both ECDFs are right-continuous step functions that only jump at observed
  # points, so evaluating on the pooled sample finds the supremum.
  grid <- c(centered_x, centered_y)
  cdf_gap <- ecdf(centered_x)(grid) - ecdf(centered_y)(grid)

  max(abs(cdf_gap))
}

# Observed outcomes implied by a potential-outcome matrix and an assignment.
#
# PO_mat: potential outcomes, one row per unit and one column per condition.
#   Column names are the condition labels.
# Z: assignment (character, factor, or any vector whose as.character()
#   values match colnames(PO_mat)), one entry per unit.
# Returns PO_mat[i, Z[i]] for each unit i, or NA where Z[i] is not a column
#   name.
# Errors if PO_mat has no column names or if some column never appears in Z.
get_po_obs <- function(PO_mat, Z) {
  condition_names <- colnames(PO_mat)
  if (is.null(condition_names)) {
    stop("PO_mat must have column names.")
  }

  if (!all(condition_names %in% unique(Z))) {
    stop("Not all PO names are in Z names")
  }

  # Column of PO_mat matching each unit's assigned condition.
  col_idx <- match(as.character(Z), condition_names)
  # Two-column (row, column) matrix indexing picks one entry per unit in a
  # single vectorized step. This replaces a per-condition loop, which matters
  # because this runs once per permutation draw.
  PO_mat[cbind(seq_along(Z), col_idx)]
}

# Vectorized wrapper around seq.default(): `from` and `to` may be vectors,
# and seq.default() is called once per (from, to) pair. Other arguments
# (e.g. `length.out`) are passed unchanged to every call.
# Results are combined with Vectorize()'s default simplification, so
# equal-length sequences (length > 1) come back as a matrix with one column
# per pair.
# NOTE(review): with length.out = 1 the result simplifies to a plain vector,
# not a 1-row matrix.
seq_vectorized <-
  Vectorize(seq.default, vectorize.args = c("from", "to"))

# Test the null of homogeneous (constant) treatment effects for every
# non-baseline condition of Z against the baseline (first level of Z), using
# the SKS permutation test in homogenous_fx_binary().
#
# Y, Z: unquoted expressions evaluated in `data`. Names not found there are
#   looked up in the caller's environment. Y must be numeric, and Y and Z
#   must have the same length. Rows where Y or Z is NA are dropped.
# data: data frame (or list/environment) holding Y and Z.
# sims, ksims, CI_method, verbose: passed to homogenous_fx_binary().
# Returns list(summary_df, bb_df):
#   - summary_df has one row per non-baseline condition, plus columns
#     `condition_names` and `dv` (the deparsed Y expression).
#   - bb_df starts with a `sim` column (1..ksims), followed by each
#     comparison's grid and p-value columns when CI_method is TRUE.
homogenous_fx_test <-
  function(Y,
           Z,
           data,
           sims = 100,
           ksims = 25,
           CI_method = TRUE,
           verbose = TRUE) {
    # Resolve names missing from `data` where the user wrote them, not inside
    # this function's own frame (where e.g. `Y = Y` would recurse).
    caller_env <- parent.frame()
    dv <- paste(deparse(substitute(Y)), collapse = "")
    Y_na <- eval(substitute(Y), data, caller_env)
    if (!is.numeric(Y_na)) {
      stop("The outcome variable (Y) must be numeric.")
    }
    Z_na <- eval(substitute(Z), data, caller_env)
    if (length(Y_na) != length(Z_na)) {
      stop("Y and Z must have the same length.")
    }

    complete <- !is.na(Y_na) & !is.na(Z_na)
    Y <- Y_na[complete]
    Z <- droplevels(factor(Z_na[complete]))
    condition_levels <- levels(Z)
    if (length(condition_levels) < 2) {
      stop("Z must contain at least two conditions with non-missing data.")
    }

    Ys <- split(x = Y, f = Z)
    arm_sizes <- lengths(Ys)

    # One binary comparison per non-baseline condition. Units are ordered
    # baseline first, then the comparison arm, so permutation draws are
    # reproducible under a fixed RNG seed. The two-level factor is built
    # explicitly rather than relying on c() of factors, which is R-version
    # dependent.
    results <- lapply(seq_along(condition_levels)[-1], function(i) {
      pair <- c(1, i)
      homogenous_fx_binary(
        Y = c(Ys[[1]], Ys[[i]]),
        Z = factor(rep(condition_levels[pair], times = arm_sizes[pair]),
                   levels = condition_levels[pair]),
        sims = sims,
        ksims = ksims,
        CI_method = CI_method,
        verbose = verbose
      )
    })

    summary_df <- do.call(rbind, lapply(results, `[[`, "summary_df"))
    # bb_df is NULL per comparison when CI_method is FALSE; skip those.
    bb_parts <- Filter(Negate(is.null), lapply(results, `[[`, "bb_df"))
    bb_df <- do.call(cbind,
                     c(list(data.frame(sim = seq_len(ksims))), bb_parts))

    summary_df$condition_names <- condition_levels[-1]
    summary_df$dv <- dv
    list(summary_df = summary_df, bb_df = bb_df)
  }



# Permutation test of homogeneous (constant) treatment effects for exactly two
# conditions, based on the shifted Kolmogorov-Smirnov statistic
# (SKS.stat_AEC).
#
# Y: numeric outcomes.
# Z: assignment with exactly two distinct values; the first factor level is
#   the baseline.
# sims: permutation draws per null distribution (must be >= 1).
# ksims: number of candidate constant effects on an even grid from LIs to UIs
#   of get_dif_in_means() (must be >= 1 when CI_method is TRUE).
# CI_method: if TRUE, also compute a p-value at every grid point and report
#   their maximum as `max_pval`.
# verbose: show a text progress bar over the ksims grid points.
# Returns list(summary_df, bb_df):
#   - summary_df holds the observed statistic (ks_obs), the p-value at the
#     difference-in-means estimate (ps), the get_dif_in_means() columns and,
#     when CI_method is TRUE, max_pval.
#   - bb_df holds the grid (<condition>_ATE) and per-grid-point p-values
#     (<condition>_p), or NULL when CI_method is FALSE.
homogenous_fx_binary <-
  function(Y,
           Z,
           sims = 200,
           ksims = 10,
           CI_method = TRUE,
           verbose = TRUE) {
    if (length(unique(Z)) != 2) {
      stop("This test can only compare two treatments at a time.")
    }
    if (sims < 1) {
      stop("sims must be at least 1.")
    }
    if (isTRUE(CI_method) && ksims < 1) {
      stop("ksims must be at least 1 when CI_method is TRUE.")
    }

    Z <- droplevels(factor(Z))

    # SKS statistic of each non-baseline arm against the baseline arm.
    sks_vs_baseline <- function(Y_by_arm) {
      vapply(Y_by_arm[-1],
             FUN = SKS.stat_AEC,
             FUN.VALUE = numeric(1),
             y = Y_by_arm[[1]])
    }

    # `sims` draws of the SKS statistic under re-randomization of Z, with
    # outcomes read from the potential-outcome matrix PO. The result has one
    # row per non-baseline arm and one column per draw. sample(Z) is the only
    # RNG call, made once per draw in order, so seeded runs are reproducible.
    simulate_sks <- function(PO) {
      draws <- matrix(NA, nrow = length(ks_obs), ncol = sims)
      for (s in seq_len(sims)) {
        Z_sim <- factor(sample(Z), levels = levels(Z))
        Y_sim <- get_po_obs(PO_mat = PO, Z = Z_sim)
        draws[, s] <- sks_vs_baseline(split(x = Y_sim, f = Z_sim))
      }
      draws
    }

    ks_obs <- sks_vs_baseline(split(x = Y, f = Z))

    d_i_m <- get_dif_in_means(Y, Z)
    PO_mat <-
      get_pos_constant_fx(Y, Z, diff_in_means_vec = d_i_m$ATEs)
    condition_names <- colnames(PO_mat)

    ks_sim <- simulate_sks(PO_mat)

    # NOTE(review): p-values use a strict `>`. The usual permutation p-value
    # counts ties (`>=`), and SKS values are discrete, so ties can occur.
    # Kept unchanged to preserve existing results; confirm before changing.
    df <- data.frame(
      condition_names = condition_names[-1],
      ks_obs = ks_obs,
      ps = rowMeans(ks_sim > ks_obs),
      d_i_m
    )

    if (isTRUE(CI_method)) {
      # Grid of ksims candidate constant effects spanning [LIs, UIs] for each
      # arm. matrix() keeps a ksims x arms shape even when ksims == 1, where
      # seq_vectorized() would return a plain vector.
      ate_seqs <- matrix(
        seq_vectorized(from = d_i_m$LIs,
                       to = d_i_m$UIs,
                       length.out = ksims),
        nrow = ksims
      )
      ps_mat <- matrix(NA, nrow = ksims, ncol = length(ks_obs))

      colnames(ate_seqs) <- paste0(condition_names, "_ATE")[-1]
      colnames(ps_mat) <- paste0(condition_names, "_p")[-1]

      if (verbose) {
        pb <- txtProgressBar(min = 0,
                             max = ksims,
                             style = 3)
        # Close the bar even if a later step errors.
        on.exit(close(pb), add = TRUE)
      }
      for (j in seq_len(ksims)) {
        if (verbose) {
          setTxtProgressBar(pb, j)
        }
        PO_j <-
          get_pos_constant_fx(Y, Z, diff_in_means_vec = ate_seqs[j, ])
        ps_mat[j, ] <- rowMeans(simulate_sks(PO_j) > ks_obs)
      }

      bb_df <- data.frame(ate_seqs, ps_mat)
      df$max_pval <- apply(ps_mat, 2, max)
    } else {
      bb_df <- NULL
    }
    rownames(df) <- NULL
    list(summary_df = df, bb_df = bb_df)
  }

# Manually fix observed spelling mistakes and unify country names in
# free-text responses.
#
# The input is lower-cased, then a fixed sequence of gsub() rewrites is
# applied. Multi-word names become single tokens (e.g. "north_korea",
# "south_korea", "saudi_arabia", "sri_lanka"). Variants of the US, UK and
# UAE become "usa", "uk" and "uae".
#
# a: character vector of responses.
# Returns the cleaned character vector. NOTE(review): the body ends with an
# assignment, so the value is returned invisibly. Callers must assign it
# (x <- change_words(x)); nothing prints at the console.
#
# IMPORTANT: every rule is a gsub() *regular expression* matched anywhere in
# the string. There are no word boundaries, and "." matches any character.
# Rules run in order, each on the output of the previous one, and several
# later rules exist only to repair artifacts of earlier ones. For example:
#   * "u.k." rewrites "turkey" to "tuky"; the later "tuky" rule restores it.
#   * the "u.s." / "us" / "usaa" rules turn "russia" into "rusaa"; the
#     "rusaa" rule restores it. "us" also turns "usa" into "usaa".
#   * "ukrain" turns "ukraine" into "ukrainee"; "ukrainee" repairs it.
# Do not reorder, de-duplicate, or anchor these rules without re-checking the
# coded output, since the results depend on this exact sequence.
change_words<-function(a){
  a<-tolower(a) 
  a<-gsub("korean","korea",a)
  a<-gsub("north korea","north_korea",a)
  a<-gsub("noth korea","north_korea",a)
  a<-gsub("n korea","north_korea",a)
  a<-gsub("north koreh","north_korea",a)
  # NOTE(review): appears unreachable -- "korean" was already rewritten to
  # "korea" above and "north korea" to "north_korea".
  a<-gsub("north korean","north_korea",a)
  a<-gsub("northkorea","north_korea",a)
  a<-gsub("northchoria","north_korea",a)
  a<-gsub("norrh korea","north_korea",a)
  a<-gsub("south korea","south_korea",a)
  # NOTE(review): appears unreachable for the same reason as "north korean".
  a<-gsub("south korean","south_korea",a)
  a<-gsub("south africa","south_africa",a)
  a<-gsub("u. s.a","usa",a)
  a<-gsub("united state of usa","usa",a)
  a<-gsub("united states of america","usa",a)
  a<-gsub("unites states of america","usa",a)
  a<-gsub("united states","usa",a)
  a<-gsub("united state","usa",a)
  # NOTE(review): unanchored, so this would also turn e.g. "south america"
  # into "south usa".
  a<-gsub("america","usa",a)
  a<-gsub("amerika","usa",a)
  a<-gsub("amireca","usa",a)
  a<-gsub("unidod states","usa",a)
  # "." is a regex wildcard in the next two rules, so they also match inside
  # unrelated words (see the header: this is how "russia" gets mangled).
  a<-gsub("u.s.","usa",a)
  a<-gsub("u.s","usa",a)
  # NOTE(review): matches "us" anywhere, e.g. "australia" -> "ausatralia" and
  # "usa" -> "usaa" (the latter repaired by the "usaa" rules).
  a<-gsub("us","usa",a)
  a<-gsub("estados unitos","usa",a)
  a<-gsub("units states","usa",a)
  a<-gsub("usaa","usa",a)
  a<-gsub("usan","usa",a)
  a<-gsub("chaina","china",a)
  a<-gsub("china's's","china",a)
  a<-gsub("chinna","china",a)
  a<-gsub("cgina","china",a)
  a<-gsub("rousdia","russia",a)
  # Repairs "russia" after the "u.s." / "us" / "usaa" rules above.
  a<-gsub("rusaa","russia",a)
  a<-gsub("russian","russia",a)
  a<-gsub("russians","russia",a)
  a<-gsub("rusia","russia",a)
  a<-gsub("rusland","russia",a)
  a<-gsub("russias","russia",a)
  a<-gsub("rusdia","russia",a)
  a<-gsub("rossim","russia",a)
  a<-gsub("rusa","russia",a)
  # Repeats an earlier rule as a second pass.
  a<-gsub("rusia","russia",a)
  a<-gsub("rassia","russia",a)
  a<-gsub("rasia","russia",a)
  a<-gsub("rzssland","russia",a)
  a<-gsub("russiabd","russia",a)
  a<-gsub("people's republic of china","china",a)
  a<-gsub("people's republic","china",a)
  a<-gsub("isreal","israel",a)
  a<-gsub("iseal","israel",a)
  a<-gsub("israeli","israel",a)
  a<-gsub("great britain","uk",a)
  a<-gsub("britain","uk",a)
  a<-gsub("britian","uk",a)
  a<-gsub("british","uk",a)
  a<-gsub("britis","uk",a)
  # "." is a regex wildcard here: "u.k." also rewrites "turkey" to "tuky",
  # which the "tuky" rule below maps back.
  a<-gsub("u.k.","uk",a)
  a<-gsub("u.k","uk",a)
  a<-gsub("london","uk",a)
  a<-gsub("united kingdom","uk",a)
  a<-gsub("united kindom","uk",a)
  a<-gsub("england","uk",a)
  a<-gsub("united arab emirates","uae",a)
  a<-gsub("united arab emirate","uae",a)
  a<-gsub("united arab","uae",a)
  a<-gsub("emirates","uae",a)
  a<-gsub("ukrein","ukraine",a)
  a<-gsub("ukarine","ukraine",a)
  a<-gsub("ukrainian","ukraine",a)
  # Also turns an already-correct "ukraine" into "ukrainee"; repaired by the
  # "ukrainee" rule further down.
  a<-gsub("ukrain","ukraine",a)
  a<-gsub("eukrain","ukraine",a)
  a<-gsub("ukrainia","ukraine",a)
  a<-gsub("ukr4aine","ukraine",a)
  a<-gsub("okrina","ukraine",a)
  a<-gsub("orkina","ukraine",a)
  a<-gsub("okranine","ukraine",a)
  a<-gsub("ukranie","ukraine",a)
  a<-gsub("ukrina","ukraine",a)
  a<-gsub("raukraine","ukraine",a)
  a<-gsub("ukain","ukraine",a)
  a<-gsub("ucrain","ukraine",a)
  a<-gsub("north and south  corea","north_korea south_korea",a)
  a<-gsub("pakistani","pakistan",a)
  a<-gsub("pakisthan","pakistan",a)
  a<-gsub("pakistantan","pakistan",a)
  a<-gsub("cameroun","cameroon",a)
  a<-gsub("cameron","cameroon",a)
  a<-gsub("cameeron","cameroon",a)
  # Restores "turkey" after the "u.k." rule above.
  a<-gsub("tuky","turkey",a)
  a<-gsub("egipto","egypt",a)
  a<-gsub("daash","isis",a)
  a<-gsub("yardan","jordan",a)
  a<-gsub("chille","chile",a)
  a<-gsub("paraguide","paraguay",a)
  a<-gsub("irac","iraq",a)
  a<-gsub("irak","iraq",a)
  a<-gsub("brazilian","brazil",a)
  a<-gsub("bangla desh","bangladesh",a)
  a<-gsub("canda","canada",a)
  # NOTE(review): the next three rules map china-like spellings to "canada";
  # "china" may have been intended. Confirm against the coded data before
  # changing.
  a<-gsub("chini","canada",a)
  a<-gsub("chins","canada",a)
  a<-gsub("chana","canada",a)
  a<-gsub("indian","india",a)
  a<-gsub("french","france",a)
  a<-gsub("fra ce","france",a)
  a<-gsub("sriilanka","sri_lanka",a)
  a<-gsub("shri lanka","sri_lanka",a)
  a<-gsub("shrilanka","sri_lanka",a)
  a<-gsub("sri lanka","sri_lanka",a)
  a<-gsub("afgan","afghanistan",a)
  a<-gsub("afganastan","afghanistan",a)
  a<-gsub("twaiwan","taiwan",a)
  a<-gsub("tawain","taiwan",a)
  a<-gsub("korea democracy","south_korea",a)
  a<-gsub("south coriaya","south_korea",a)
  a<-gsub("coreea nord","north_korea",a)
  a<-gsub("dprk","north_korea",a)
  a<-gsub("k await","kuwait",a)
  a<-gsub("japon","japan",a)
  a<-gsub("japanese","japan",a)
  # Repair "ukrainee" / "ukraineia" produced by the "ukrain" rule above.
  a<-gsub("ukrainee","ukraine",a)
  a<-gsub("ukraineia","ukraine",a)
  a<-gsub("ukine","ukraine",a)
  a<-gsub("irian","iran",a)
  a<-gsub("irano","iran",a)
  a<-gsub("airan","iran",a)
  a<-gsub("iram","iran",a)
  a<-gsub("persia","iran",a)
  # Second pass of earlier rules, after the intervening rewrites.
  a<-gsub("usaa","usa",a)
  a<-gsub("pakistantan","pakistan",a)
  # Repair double suffixes left by the "afgan" rule
  # (e.g. "afganistan" -> "afghanistanistan").
  a<-gsub("afghanistanistan","afghanistan",a)
  a<-gsub("afghanistanstan","afghanistan",a)
  a<-gsub("saudi arabia","saudi_arabia",a)
  a<-gsub("n. korea","north_korea",a)
  a<-gsub("north koea","north_korea",a)
  a<-gsub("anorth_korea","north_korea",a)
  a<-gsub("s korea","south_korea",a)
  a<-gsub("russiaia","russia",a)
  a<-gsub("iranian","iran",a)
  a<-gsub("korea people's democratic republic","north_korea",a)
  a<-gsub("italian","italy",a)
  a<-gsub("syrians","syria",a)
  a<-gsub("egyptians","egypt",a)
  a<-gsub("palestinians","palestine",a)
  a<-gsub("palestinian","palestine",a)
  a<-gsub("palestina","palestine",a)
  a<-gsub("lebanese","lebanon",a)
  a<-gsub("quatar","qatar",a)
  a<-gsub("ukrainea","ukraine",a)
  a<-gsub("uecraune","ukraine",a)
  a<-gsub("russialand","russia",a)
  a<-gsub("chinae","china",a)
  a<-gsub("chinw","china",a)
  a<-gsub("pak8stan","pakistan",a)
  a<-gsub("norwegian","norway",a)
  a<-gsub("argentine","argentina",a)
  a<-gsub("chine","china",a)
  a<-gsub("usaof a","usa",a)
  a<-gsub("rasiya","russia",a)
  a<-gsub("chinta","china",a)
  a<-gsub("chin a","china",a)
  a<-gsub("chinese","china",a)
  # Two-step normalization: collapse "germany" to "german" first, so that the
  # next rule maps both forms to "germany" without producing "germanyy".
  a<-gsub("germany","german",a)
  a<-gsub("german","germany",a)
  a<-gsub("germeny","germany",a)
  a<-gsub("chinase","china",a)
  
}